#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
# Copyright 2025 The Lynx Authors. All rights reserved.
# Licensed under the Apache License Version 2.0 that can be found in the
# LICENSE file in the root directory of this source tree.


import os
import sys
import yaml
import argparse
from jni_generator import GenerateJNIHeader, Options

def parse_yaml(root_path, input_file_path):
  """Load a YAML config file, resolving the custom `!include` tag.

  Args:
    root_path: Directory that file names used in `!include` are joined onto.
    input_file_path: Path of the YAML file to load.

  Returns:
    The object produced by `yaml.full_load` for `input_file_path`, with every
    `!include` scalar replaced by the data it refers to.

  NOTE(review): `yaml.full_load` is used on both the main file and included
  files; this presumes the configs are trusted project files - confirm.
  """
  def join_path(dir_info, node):
    # `dir_info` is either '' (nothing to do) or `key:parent_path`; when
    # `node[key]` is a string it is prefixed with `parent_path`.
    if dir_info == '':
      return node
    # Split on the first ':' only, so a parent path that itself contains a
    # colon (e.g. a Windows drive such as `C:/src`) is kept intact instead of
    # failing to unpack.
    key, parent_dir = dir_info.split(':', 1)
    if key != '' and key in node.keys() and isinstance(node[key], str):
      node[key] = os.path.join(parent_dir, node[key])
    return node

  # define !include constructor
  # usage: !include file_name | key1.key2 & key3:parent_path
  # note: key3 is the key value of elements inside key1.key2
  def include_constructor(loader, node):
    # All spaces are stripped from the directive before it is split on '|'.
    file_name, key_path = loader.construct_scalar(node).replace(' ', '').split('|')
    dir_info = ''
    if '&' in key_path:
      key_path, dir_info = key_path.split('&')
    print(f"Processing !include directive - Extracted file name: {file_name}, key: {key_path}, directory: {dir_info}")
    file_path = os.path.join(root_path, file_name)
    with open(file_path, 'r') as f:
      data = yaml.full_load(f)
    # Walk down the dotted key path inside the included document.
    for key in key_path.split('.'):
      data = data.get(key, None)
    if dir_info != '':
      data = [join_path(dir_info, d) for d in data]
    return data

  # register !include constructor
  yaml.add_constructor('!include', include_constructor)
  with open(input_file_path, 'r') as file:
    return yaml.full_load(file)

# Template of the generated so-load C++ source. The upper-case placeholders
# CUSTOM_HEADERS, AUTO_GENERATED_INCLUDE_HEADERS_CONTENT, NAMESPACE_START,
# NAMESPACE_END and AUTO_GENERATED_REGISTER_METHODS are filled in with
# str.replace by append_content_to_so_registry.
so_load_file_template = """
// This file is autogenerated.

// Copyright 2025 The Lynx Authors. All rights reserved.
// Licensed under the Apache License Version 2.0 that can be found in the
// LICENSE file in the root directory of this source tree.

CUSTOM_HEADERS
// AUTO_GENERATED_INCLUDE_HEADERS_START
AUTO_GENERATED_INCLUDE_HEADERS_CONTENT
// AUTO_GENERATED_INCLUDE_HEADERS_END

NAMESPACE_START
extern "C" JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void* reserved) {
  lynx::base::android::InitVM(vm);
  JNIEnv* env = lynx::base::android::AttachCurrentThread();
AUTO_GENERATED_REGISTER_METHODS
  return JNI_VERSION_1_6;
}

NAMESPACE_END
"""

# Template of the generated BUILD.gn file. The placeholders CUSTOM_HEADERS,
# TEMPLATE_NAME, SOURCE_FILES and CONFIG_TARGETS are filled in with
# str.replace by append_files_to_gn.
build_gn_template = """# This file is autogenerated.

# Copyright 2025 The Lynx Authors. All rights reserved.
# Licensed under the Apache License Version 2.0 that can be found in the
# LICENSE file in the root directory of this source tree.
CUSTOM_HEADERS
TEMPLATE_NAME("build") {
  sources = [SOURCE_FILES]
CONFIG_TARGETS}
"""

# Template of a generated XXX_register_jni.h header. The placeholders
# HEADER_GUARD, JAVA_FILE_PATH, FUNCTION_NAME, NAMESPACE_START and
# NAMESPACE_END are filled in with str.replace by generate_register_header.
register_header_template = """
// This file is autogenerated for
//     JAVA_FILE_PATH

#ifndef HEADER_GUARD
#define HEADER_GUARD

#include <jni.h>

NAMESPACE_START

bool FUNCTION_NAME(JNIEnv* env);

NAMESPACE_END

#endif  // HEADER_GUARD

"""
def get_relative_path(path1, path2):
  """Return `path2` expressed relative to the directory containing `path1`."""
  base_dir = os.path.dirname(path1)
  return os.path.relpath(path2, base_dir)

def convert_to_relative_path(root_path, src_path, des_path):
  """Return `des_path` relative to the directory of `src_path`.

  Both paths are first joined onto `root_path`.
  """
  source = os.path.join(root_path, src_path)
  destination = os.path.join(root_path, des_path)
  return get_relative_path(source, destination)

def get_namespace(configs):
  """Return the configured `namespaces` list (default `['lynx']`) followed by 'jni'."""
  declared_namespaces = configs.get('namespaces', ['lynx'])
  return declared_namespaces + ['jni']

def get_namespace_guard(configs):
  """Build the C++ namespace opening and closing strings for `configs`.

  Returns:
    A `(start, end)` tuple. `start` opens every namespace from
    `get_namespace(configs)` in order on one line; `end` closes them in
    reverse order, one `}  // namespace NAME` line per namespace.
  """
  namespaces = get_namespace(configs)
  opening_parts = ['namespace {} '.format(name) + '{' for name in namespaces]
  # The innermost namespace is closed first, hence the reversed iteration.
  closing_parts = ['}' + '  // namespace {}\n'.format(name) for name in reversed(namespaces)]
  return ''.join(opening_parts), ''.join(closing_parts)

def format_file(file_path):
  """Format `file_path` in place by running an external formatter.

  Paths ending in '.gn' are formatted through `gn_tools/gn_wrapper.py`, which
  is located two directory levels above this script; every other file is
  passed to `clang-format -i`. The command's exit status is not checked.
  """
  if file_path.endswith('.gn'):
    tools_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    gn_wrapper = os.path.join(tools_dir, 'gn_tools', 'gn_wrapper.py')
    command = 'python3 {} format {}'.format(gn_wrapper, file_path)
  else:
    command = 'clang-format -i {}'.format(file_path)
  os.system(command)

def write_content_to_file(file_path, content):
  """Write `content` to `file_path` and format it, unless it is already there.

  When the file exists and its current text equals `content`, nothing is
  written and the formatter is not run.
  """
  if os.path.exists(file_path):
    with open(file_path, 'r') as existing:
      is_unchanged = existing.read() == content
    if is_unchanged:
      return
  with open(file_path, 'w') as output:
    output.write(content)
  format_file(file_path)

def generate_register_header(java_file, function_name, register_header_path, namespace_start_str, namespace_end_str):
  """Generate XXX_register_jni.h, which declares the RegisterJNIForXXX method.

  Args:
    java_file: Path of the Java source; it is echoed in the header comment
      and used to derive the include guard.
    function_name: Name of the register function declared in the header.
    register_header_path: Path the header is written to.
    namespace_start_str: Text substituted for NAMESPACE_START.
    namespace_end_str: Text substituted for NAMESPACE_END.
  """
  # Derive the include guard from the Java path without its extension.
  # Only the extension is dropped (not everything after the first '.'), so
  # paths such as '../x/Foo.java' still produce distinct guards, and every
  # character that cannot appear in a macro name becomes '_'.
  path_without_ext = os.path.splitext(java_file)[0]
  guard_chars = [
      ch.upper() if (ch.isalnum() and ord(ch) < 128) else '_'
      for ch in path_without_ext
  ]
  # Leading underscores are dropped because identifiers that start with an
  # underscore followed by an upper-case letter are reserved in C++.
  guard_string = ''.join(guard_chars).lstrip('_') + '_REGISTER_JNI_H'

  header_filled_template = register_header_template.replace('HEADER_GUARD', guard_string)
  header_filled_template = header_filled_template.replace('JAVA_FILE_PATH', java_file)
  header_filled_template = header_filled_template.replace('FUNCTION_NAME', function_name)
  header_filled_template = header_filled_template.replace('NAMESPACE_START', namespace_start_str)
  header_filled_template = header_filled_template.replace('NAMESPACE_END', namespace_end_str)
  write_content_to_file(register_header_path, header_filled_template)

def append_content_if_changed(file_path, start_flag, end_flag, new_content_list):
  """Insert `new_content_list` after the `start_flag` line when it is new or changed.

  The lines strictly between the first line containing `start_flag` and the
  first line containing `end_flag` are searched for an existing copy of the
  content. A copy starts at a line containing `new_content_list[0]` and runs
  up to the next line that is exactly '\n'. If no copy is found, or the copy
  differs in length or in any line, the old copy (if any) is removed and the
  new content plus a blank line is inserted right after the `start_flag`
  line, then the file is rewritten.

  Args:
    file_path: File to read and, when changed, rewrite.
    start_flag: Substring marking the line that opens the searched region.
    end_flag: Substring marking the line that closes the searched region.
    new_content_list: Lines to insert; element 0 is read unconditionally, so
      the list must not be empty.
  """
  # Find content that starts at start_flag and ends at end_flag
  file_lines = []
  matched_lines = []
  start = False
  end = False
  # 1-based line number of the start_flag line, which is also the 0-based
  # index in file_lines of the line right after it. Stays 0 if never found.
  file_start_num = 0
  # Recorded below but not read again in this function.
  file_end_num = 0
  with open(file_path, 'r') as file:
    index = 0
    while True:
        line = file.readline()
        if not line:
          break
        file_lines.append(line)
        index = index + 1
        if not start and start_flag in line:
          start = True
          file_start_num = index
          # The flag line itself is not part of the matched region.
          continue
        if not end and end_flag in line:
          end = True
          file_end_num = index
        # Skip lines before the start flag, and the end flag line onwards.
        if (not start) or (start and end):
          continue
        matched_lines.append(line)

  # Find where the new content begins in old content.
  found = False
  first_line = new_content_list[0]
  start_line_num = 0
  end_line_num = 0
  for index, line in enumerate(matched_lines):
    # No break here: a later matching line before the terminating blank line
    # overwrites start_line_num.
    if first_line in line:
      found = True
      start_line_num = index
      
    # The first blank line at or after a match ends the old copy; if there is
    # none, end_line_num stays 0.
    if found and line == '\n':
      end_line_num = index
      break

  # Insert new content to file
  changed = False
  if not found:
    changed = True
  elif found and end_line_num - start_line_num != len(new_content_list):
    changed = True
    # Remove old content
    # Indices into matched_lines are shifted by file_start_num to address
    # file_lines; the +1 also removes the terminating blank line.
    start = file_start_num + start_line_num
    end = file_start_num + end_line_num + 1
    print('delete line {} to {}'.format(start, end))
    del file_lines[start:end]
  else:
    # Same length: compare line by line and drop the old copy on the first
    # mismatch.
    for index, line in enumerate(new_content_list):
      if not line == matched_lines[index + start_line_num]:
        changed = True
        # Remove old content
        start = file_start_num + start_line_num
        end = file_start_num + end_line_num + 1
        print('delete line {} to {}'.format(start, end))
        del file_lines[start:end]
        break
  
  if changed:
    # Write new content to file
    print('file changed')
    new_content_string = ''.join(new_content_list)
    # Inserted as a single element, followed by a blank line that marks the
    # end of this copy for the next run.
    file_lines.insert(file_start_num, new_content_string + '\n')
    with open(file_path, 'w') as file:
      file.writelines(file_lines)

# Write content to XXXSoLoad.cc
def append_content_to_so_registry(so_configs, include_headers, register_methods):
  """Render `so_load_file_template` and write it to the configured path.

  Args:
    so_configs: Mapping read for 'output_path', 'custom_headers' and (through
      get_namespace_guard) 'namespaces'.
    include_headers: Sequence of `[include_line, macro]` pairs. A non-empty
      macro wraps the include line in `#if macro` / `#endif`.
    register_methods: Sequence of `[call_statement, macro]` pairs, wrapped in
      the same way; each statement is indented by two spaces.
  """
  cpp_output_path = so_configs.get('output_path', '')
  custom_headers = so_configs.get('custom_headers', [])

  def guarded(line, macro):
    # Wrap `line` in a preprocessor condition when a macro is given.
    if len(macro) != 0:
      return f'#if {macro}\n' + line + '#endif\n'
    return line

  # custom headers
  header_str = ''.join('#include "{}"\n'.format(header) for header in custom_headers)

  # namespaces
  namespace_start_str, namespace_end_str = get_namespace_guard(so_configs)

  # include headers
  include_headers_str = ''.join(
      guarded(f'{entry[0]}\n', entry[1]) for entry in include_headers)

  # register methods
  register_methods_str = ''.join(
      guarded('  ' + entry[0] + '\n', entry[1]) for entry in register_methods)

  so_file_str = so_load_file_template
  so_file_str = so_file_str.replace('CUSTOM_HEADERS', header_str)
  so_file_str = so_file_str.replace('NAMESPACE_START', namespace_start_str)
  so_file_str = so_file_str.replace('NAMESPACE_END', namespace_end_str)
  so_file_str = so_file_str.replace('AUTO_GENERATED_INCLUDE_HEADERS_CONTENT', include_headers_str)
  so_file_str = so_file_str.replace('AUTO_GENERATED_REGISTER_METHODS', register_methods_str)

  directory = os.path.dirname(cpp_output_path)
  # dirname is '' when the output path has no directory part, and
  # os.makedirs('') raises; exist_ok avoids a check-then-create race.
  if directory:
    os.makedirs(directory, exist_ok=True)
  write_content_to_file(cpp_output_path, so_file_str)
  
# Write content to BUILD.gn
def append_files_to_gn(root_path, gn_configs, gn_files):
  """Render `build_gn_template` and write it to `gn_configs['output_path']`.

  Args:
    root_path: Root used to turn configured paths into paths relative to the
      gn file.
    gn_configs: Mapping read for 'output_path', 'custom_headers',
      'template_name' (default 'source_set') and 'custom_configs'.
    gn_files: Source entries written verbatim into the `sources` list.
  """
  template_default_name = 'source_set'
  gn_file_path = gn_configs.get('output_path', '')
  custom_headers = gn_configs.get('custom_headers', [])
  template_name = gn_configs.get('template_name', template_default_name)
  dep_configs = gn_configs.get('custom_configs', [])

  # import("...") lines, preceded by one blank line when there are any.
  header_str = ''
  if len(custom_headers) > 0:
    import_lines = []
    for custom_header in custom_headers:
      relative_header = convert_to_relative_path(root_path, gn_file_path, custom_header)
      import_lines.append(f'import("{relative_header}")\n')
    header_str = '\n' + ''.join(import_lines)

  # One quoted entry per source file.
  source_list_str = ''
  if len(gn_files) > 0:
    source_entries = ''.join(f'    "{file}",\n' for file in gn_files)
    source_list_str = f'\n{source_entries}  '

  # The default template appends to `configs`; any other template assigns it.
  config_list_str = ''
  if len(dep_configs) > 0:
    if template_name == template_default_name:
      configs_template = '  configs += [CONFIGS]\n'
    else:
      configs_template = '  configs = [CONFIGS]\n'
    config_entries = []
    for config in dep_configs:
      relative_config = convert_to_relative_path(root_path, gn_file_path, config)
      config_entries.append(f'    "{relative_config}",\n')
    config_list_str = configs_template.replace('CONFIGS', f'\n{"".join(config_entries)}  ')

  gn_file_str = build_gn_template
  gn_file_str = gn_file_str.replace('CUSTOM_HEADERS', header_str)
  gn_file_str = gn_file_str.replace('TEMPLATE_NAME', template_name)
  gn_file_str = gn_file_str.replace('SOURCE_FILES', source_list_str)
  gn_file_str = gn_file_str.replace('CONFIG_TARGETS', config_list_str)

  write_content_to_file(gn_file_path, gn_file_str)

def get_modification_time_str(root_path, file_name):
  """Return `<path>:<mtime>#@` for `root_path/file_name`, or '' if it is missing."""
  file_path = os.path.join(root_path, file_name)
  if os.path.exists(file_path):
    timestamp = os.path.getmtime(file_path)
    return f'{file_path}:{timestamp}#@'
  return ''

def get_java_class_list(java_class):
  """Return the 'java' entry of `java_class` as a collection of paths.

  A list or set (including subclasses) is returned as is; any other value,
  including the '' default used when the key is missing, is wrapped in a
  one-element list.
  """
  java_path_raw = java_class.get('java', '')
  if isinstance(java_path_raw, (list, set)):
    return java_path_raw
  return [java_path_raw]
def should_skip_generate_jni(root_path, jni_configs_file, jni_classes, jni_output_path, gn_file_path, so_load_file_path, use_base_jni_utils_header):
  """Report whether generation can be skipped because nothing has changed.

  A fingerprint is built from the modification times of the input and output
  files plus the `use_base_jni_utils_header` flag, and compared with the one
  stored in `timestamps.txt` inside the JNI output directory.

  Returns:
    True when the stored fingerprint equals the current one. Otherwise the
    new fingerprint is written to `timestamps.txt` and False is returned.
    The JNI output directory is created when it does not exist.
  """
  # timestamps of files:
  # input: current script / jni_generator.py / config yaml script / java files
  # output: so load file / jni files / gn file
  parts = []

  # input file timestamps
  current_script_name = os.path.relpath(os.path.abspath(__file__), root_path)
  parts.append(get_modification_time_str(root_path, current_script_name))

  jni_generator_name = os.path.join(os.path.dirname(current_script_name), 'jni_generator.py')
  parts.append(get_modification_time_str(root_path, jni_generator_name))

  parts.append(get_modification_time_str(root_path, jni_configs_file))

  for java_class in jni_classes:
    for java_path in get_java_class_list(java_class):
      parts.append(get_modification_time_str(root_path, java_path))

  # output file timestamps
  parts.append(get_modification_time_str(root_path, jni_output_path))
  parts.append(get_modification_time_str(root_path, gn_file_path))
  parts.append(get_modification_time_str(root_path, so_load_file_path))

  jni_file_dir = os.path.join(root_path, jni_output_path)
  if os.path.exists(jni_file_dir):
    for walk_root, dirs, files in os.walk(jni_file_dir):
      # Sort in place so the traversal, and therefore the fingerprint, does
      # not depend on the order the file system lists entries in.
      dirs.sort()
      for file in sorted(files):
        if file == 'timestamps.txt':
          continue
        # Keep the sub directory of nested files; joining the bare file name
        # onto the output dir would point at the wrong path for them.
        rel_to_output = os.path.relpath(os.path.join(walk_root, file), jni_file_dir)
        file_name = os.path.join(jni_output_path, rel_to_output)
        parts.append(get_modification_time_str(root_path, file_name))
  else:
    os.makedirs(jni_file_dir, exist_ok=True)

  parts.append(f'use_base_jni_utils_header:{use_base_jni_utils_header}')
  timestamps = ''.join(parts)

  # write timestamps
  timestamps_file = os.path.join(jni_file_dir, 'timestamps.txt')

  if os.path.exists(timestamps_file):
    with open(timestamps_file, 'r') as file:
      # read() rather than readlines()[0]: an empty file must count as
      # "changed" instead of raising IndexError.
      old_timestamps = file.read()
    if timestamps == old_timestamps:
      return True

  with open(timestamps_file, 'w') as file:
    file.write(timestamps)

  return False

def get_jni_classes(jni_classes_list):
  """Flatten arbitrarily nested lists of JNI class entries into one list.

  Elements that are lists (including list subclasses) are flattened
  recursively; every other element is kept as is, in order.
  """
  classes = []
  for jni_class in jni_classes_list:
    if isinstance(jni_class, list):
      classes.extend(get_jni_classes(jni_class))
    else:
      classes.append(jni_class)
  return classes

def generate_files(root_path, jni_configs_file, use_base_jni_utils_header):
  """Generate the JNI headers, register headers, so-load source and BUILD.gn.

  Args:
    root_path: Project root; every configured path is joined onto it.
    jni_configs_file: Path of the YAML config, relative to `root_path`.
    use_base_jni_utils_header: Flag passed to `Options` of jni_generator.

  Returns:
    0 when generation finished or was skipped because nothing changed,
    -1 when the configuration is invalid.
  """
  # Parse jni_files yaml file to a map
  jni_configs_abs_path = os.path.join(root_path, jni_configs_file)
  jni_configs = parse_yaml(root_path, jni_configs_abs_path)

  # Read config from yaml map
  # Every section is used as a mapping below, so a missing or empty section
  # falls back to an empty dict.
  jni_classes_configs = jni_configs.get('jni_class_configs') or {}
  jni_register_configs = jni_configs.get('jni_register_configs') or {}
  gn_configs = jni_configs.get('gn_configs') or {}

  jni_classes = get_jni_classes(jni_classes_configs.get('jni_classes', []))
  jni_output_path = jni_classes_configs.get('output_dir')
  if not isinstance(jni_output_path, str):
    # The output dir is joined onto root_path unconditionally below, so fail
    # with a clear message instead of a TypeError from os.path.join.
    print(f"Error: `jni_class_configs.output_dir` is missing or not a string in {jni_configs_file}.")
    return -1
  special_cases = jni_register_configs.get('special_cases', [])
  so_load_file_path = jni_register_configs.get('output_path', '')
  gn_file_path = gn_configs.get('output_path', '')

  # Handle special cases that are shared or not automatically generated
  excluded_java_files = set()
  special_headers = []
  special_methods = []
  for special_case in special_cases:
    header = special_case.get('header', '')
    method = special_case.get('method', '')
    java = special_case.get('java', '')
    macro = special_case.get('macro', '')
    if header != '':
      special_headers.append(['#include "' + header + '"', macro])
    if method != '':
      if isinstance(method, list):
        special_methods.extend([m, macro] for m in method)
      else:
        special_methods.append([method, macro])
    if java != '':
      excluded_java_files.add(java)

  # Validate before the timestamps file is touched and before anything is
  # generated; otherwise a failed run could be followed by a run that is
  # skipped as "unchanged" and reports success.
  for jni_class in jni_classes:
    if jni_class.get('register_method_name', '') != '' and len(get_java_class_list(jni_class)) > 1:
      print(f"register_method_name cannot match multiple Java classes, please declare them separately.")
      return -1

  if should_skip_generate_jni(root_path, jni_configs_file, jni_classes, jni_output_path, gn_file_path, so_load_file_path, use_base_jni_utils_header):
    return 0

  # Read Java files and assemble include header and register method and gn files
  # e.g.:
  #  com/lynx/tasm/behavior/PaintingContext.java
  #    include header is #include "${path}/PaintingContext_register_jni.h"
  #    register method is RegisterJNIForPaintingContext(env)
  processed_java_files = set()
  include_headers = []
  register_methods = []
  gn_files = []
  # Both depend only on the register configs, so compute them once.
  namespace_prefix = '::'.join(get_namespace(jni_register_configs))
  namespace_start_str, namespace_end_str = get_namespace_guard(jni_register_configs)
  for jni_class in jni_classes:
    java_path_list = get_java_class_list(jni_class)
    custom_register_method_name = jni_class.get('register_method_name', '')
    macro = jni_class.get('macro', '')
    for java_path in java_path_list:
      if java_path in processed_java_files:
        continue
      if java_path in excluded_java_files:
        continue
      java_file_full_path = os.path.join(root_path, java_path)
      if not os.path.exists(java_file_full_path):
        print(f'Error: {java_file_full_path} is not exist.')
        continue
      processed_java_files.add(java_path)
      java_file_name = os.path.basename(java_path)
      java_base_name = os.path.splitext(java_file_name)[0]
      # Resolve the name per Java file. Reusing one variable across the loop
      # made the second file of a multi-file entry inherit the name derived
      # from the first one.
      register_method_name = custom_register_method_name or ('RegisterJNIFor' + java_base_name)
      # file name
      jni_file_name = java_base_name + '_jni.h'
      jni_register_header_name = java_base_name + '_register_jni.h'
      # file path
      jni_file_path = os.path.join(root_path, jni_output_path, jni_file_name).replace("\\", "/")
      jni_file_rel_path = os.path.relpath(jni_file_path, root_path).replace("\\", "/")
      jni_register_header_abs_path = os.path.join(root_path, jni_output_path, jni_register_header_name).replace("\\", "/")
      jni_register_header_rel_path = os.path.relpath(jni_register_header_abs_path, root_path).replace("\\", "/")

      include_header = f'#include "{jni_register_header_rel_path}"'
      include_headers.append([include_header, macro])
      register_method = namespace_prefix + f'::{register_method_name}(env);'
      register_methods.append([register_method, macro])

      gn_files.append(convert_to_relative_path(root_path, gn_file_path, jni_file_rel_path))
      gn_files.append(convert_to_relative_path(root_path, gn_file_path, jni_register_header_rel_path))

      # generate jni
      options = Options(use_base_jni_utils_header)
      print(jni_file_path)
      GenerateJNIHeader(java_file_full_path, jni_file_path, options)
      # generate register method header file
      generate_register_header(java_path,
                               register_method_name,
                               jni_register_header_abs_path,
                               namespace_start_str,
                               namespace_end_str)

  # generate SoLoad.cc
  if not so_load_file_path:
    print(f"Warning: so load file path is empty, you can input the path by `jni_register_configs.output_path` in {jni_configs_file}.")
  else:
    include_headers.extend(special_headers)
    # Special methods keep their configured order ahead of the sorted
    # generated ones.
    register_methods.sort()
    register_methods = special_methods + register_methods
    include_headers.sort()
    jni_register_configs['output_path'] = os.path.join(root_path, so_load_file_path)
    append_content_to_so_registry(jni_register_configs, include_headers, register_methods)
    gn_files.append(convert_to_relative_path(root_path, gn_file_path, so_load_file_path))
  # generate BUILD.gn
  if not gn_file_path:
    print(f"Warning: gn file path is empty, you can input the path by `gn_configs.output_path` in {jni_configs_file}.")
  else:
    gn_files.sort()
    gn_configs['output_path'] = os.path.join(root_path, gn_file_path)
    append_files_to_gn(root_path, gn_configs, gn_files)

  return 0

def main():
  """Parse the command line and run `generate_files`.

  Returns:
    The value returned by `generate_files`.
  """
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('-path', '--jni-config-path', required=True, type=str, dest='jni_configs_file', help='Refer to `tools/build_jni/testing/jni_configs.yml` for declaration.')
  arg_parser.add_argument('-root', '--root-dir', required=True, type=str, help='Project root dir')
  arg_parser.add_argument('--use-base-jni-header', dest='use_base_jni_utils_header', action='store_true')
  parsed = arg_parser.parse_args()

  return generate_files(parsed.root_dir,
                        parsed.jni_configs_file,
                        parsed.use_base_jni_utils_header)

# When run as a script, use the value returned by main() as the process exit
# status.
if __name__ == "__main__":
  sys.exit(main())
